"""
Phase 6: Financial center robustness check.

Re-estimates Models 2b and 2c excluding pairs involving financial center
jurisdictions (LUX, IRL, NLD, CYM, BMU, BHS, PAN, HKG, SGP, CHE, etc.)
to ensure results are not driven by portfolio transit hubs.

Output: gravity_bilateral/output/tables/financial_center_robustness.csv
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys

sys.path.insert(0, str(Path("/mnt/c/demographics_capital_flows/gravity_bilateral")))
from src.model import PanelGLS

# Project locations: processed input panel and destination for result tables.
BASE_DIR = Path("/mnt/c/demographics_capital_flows/gravity_bilateral")
PROCESSED_DIR = BASE_DIR.joinpath("data", "processed")
OUTPUT_DIR = BASE_DIR.joinpath("output", "tables")

# Financial centers / portfolio transit hubs, as three-letter country codes.
# Narrow list: jurisdictions widely recognized as offshore or pass-through.
FINANCIAL_CENTERS_NARROW = {
    'BHR',
    'BHS',
    'BMU',
    'CYM',
    'CYP',
    'IRL',
    'LUX',
    'MLT',
    'MUS',
    'PAN',
    'VGB',
}
# Broad list: the narrow list plus major financial hubs with large transit
# positions.
FINANCIAL_CENTERS_BROAD = FINANCIAL_CENTERS_NARROW.union({
    'BEL',
    'CHE',
    'GBR',
    'HKG',
    'NLD',
    'SGP',
})


def estimate_model(df, dep_var, regressors, model_name, min_obs=100):
    """Estimate a gravity model with year dummies and return tidy results.

    Args:
        df: Panel containing ``dep_var``, every column named in
            ``regressors``, ``pair_id``, ``year`` and (optionally) pre-built
            ``yr_<year>`` dummy columns.
        dep_var: Name of the dependent-variable column.
        regressors: Sequence of regressor column names (list, tuple, ...).
        model_name: Label written to the ``model`` column of the output.
        min_obs: Minimum number of complete rows required to estimate.

    Returns:
        A DataFrame with one row per regressor (``coefficient``,
        ``std_error``, ``t_stat``, ``p_value``) followed by the metadata rows
        ``_R_squared``, ``_N_obs`` and ``_N_pairs``; ``None`` when fewer than
        ``min_obs`` complete rows are available.
    """
    regressors = list(regressors)

    # Fix the estimation sample on the substantive variables first, so the
    # year dummies can be chosen from the years that actually survive.
    est = df.dropna(subset=[dep_var] + regressors + ['pair_id', 'year'])

    # Use dummies only for years present in the estimation sample and omit the
    # first of those as the reference year. Picking them from the unfiltered
    # frame can leave an all-zero column (year fully dropped) or a dummy set
    # that sums to one (reference year absent), both of which make the design
    # matrix rank-deficient.
    # NOTE(review): the reference-year choice assumes PanelGLS fits an
    # intercept -- confirm against src.model.
    sample_years = sorted(est['year'].unique())
    yr_cols = [f'yr_{int(y)}' for y in sample_years[1:]]
    yr_cols = [c for c in yr_cols if c in est.columns]
    all_vars = regressors + yr_cols

    # Dummies are expected to be complete 0/1 indicators; drop any stray NaN
    # rows anyway so the design matrix never contains missing values.
    est = est.dropna(subset=yr_cols).copy()
    if len(est) < min_obs:
        print(f"  {model_name}: insufficient obs ({len(est)})")
        return None

    y = est[dep_var].values
    X = est[all_vars].values

    gls = PanelGLS()
    gls.fit(y, X, est['pair_id'].values, est['year'].values.astype(int))

    # Regressors occupy the leading columns of X, so their estimates sit at
    # the same positions in the fitted arrays; year-dummy estimates that
    # follow are not reported.
    results = [
        {
            'model': model_name,
            'variable': v,
            'coefficient': gls.beta[i],
            'std_error': gls.se[i],
            't_stat': gls.tvalues[i],
            'p_value': gls.pvalues[i],
        }
        for i, v in enumerate(regressors)
    ]
    # Fit statistics are stored as pseudo-variables in the coefficient column.
    # NOTE(review): `_N_pairs` is filled from `gls.n_countries`; presumably the
    # number of panel units passed as pair_id -- confirm against PanelGLS.
    for meta_var, meta_val in [('_R_squared', gls.r_squared),
                               ('_N_obs', gls.n_obs),
                               ('_N_pairs', gls.n_countries)]:
        results.append({
            'model': model_name,
            'variable': meta_var,
            'coefficient': meta_val,
            'std_error': np.nan, 't_stat': np.nan, 'p_value': np.nan,
        })
    return pd.DataFrame(results)


def _significance_stars(p_value):
    """Return the star marker for a p-value ('' when >= 0.1 or NaN)."""
    if p_value < 0.01:
        return '***'
    if p_value < 0.05:
        return '**'
    if p_value < 0.1:
        return '*'
    return ''


def _estimate_specs(sample, dep_var, specs, label):
    """Estimate each (prefix, regressors) spec on `sample`; skip failed fits."""
    fitted = []
    for prefix, regressors in specs:
        res = estimate_model(sample, dep_var, regressors, f"{prefix}: {label}")
        if res is not None:
            fitted.append(res)
    return fitted


def main():
    """Run the financial-center robustness check and write the results table.

    Loads the bilateral panel, builds year dummies, estimates Models 2b and 2c
    on the full sample and on samples excluding the narrow and broad
    financial-center lists, saves every estimate to
    ``financial_center_robustness.csv`` in ``OUTPUT_DIR`` and prints how the
    ``dZ_1`` and ``dZ_1_x_kaopen_j`` coefficients move across samples.
    """
    print("=" * 70)
    print("PHASE 6: FINANCIAL CENTER ROBUSTNESS")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "bilateral_panel.csv")
    print(f"Full panel: {len(df):,} obs")
    if df.empty:
        # Nothing to estimate; also avoids dividing by zero when reporting
        # the share of dropped observations below.
        print("Panel is empty; nothing to estimate.")
        return

    # Create year dummies, omitting the first year as the reference category
    years = sorted(df['year'].dropna().unique())
    for y in years[1:]:
        df[f'yr_{int(y)}'] = (df['year'] == y).astype(int)

    dep_var = 'log_portfolio_total'
    gravity_vars = ['log_dist', 'contiguity', 'common_lang_official', 'colonial_ties', 'log_gdp_product']
    demo_vars = ['dZ_1', 'dZ_2', 'dZ_3']
    kaopen_vars = ['kaopen_j', 'dZ_1_x_kaopen_j', 'dZ_2_x_kaopen_j', 'dZ_3_x_kaopen_j']

    # Model 2b: gravity + demographics; Model 2c adds the KAOPEN terms.
    specs = [
        ('2b', gravity_vars + demo_vars),
        ('2c', gravity_vars + demo_vars + kaopen_vars),
    ]

    all_results = []

    # Full sample baselines (for comparison)
    print("\n--- Full Sample ---")
    all_results.extend(_estimate_specs(df, dep_var, specs, "Full sample"))

    # Narrow and broad exclusions: drop every pair in which either side is a
    # listed financial center.
    for label, fc_set in [("Narrow FC exclusion", FINANCIAL_CENTERS_NARROW),
                          ("Broad FC exclusion", FINANCIAL_CENTERS_BROAD)]:
        mask = ~(df['reporter'].isin(fc_set) | df['partner'].isin(fc_set))
        df_excl = df[mask].copy()
        n_dropped = len(df) - len(df_excl)
        print(f"\n--- {label} ({len(fc_set)} countries) ---")
        print(f"  Dropped: {n_dropped:,} obs ({n_dropped/len(df)*100:.1f}%)")
        print(f"  Remaining: {len(df_excl):,} obs")

        all_results.extend(_estimate_specs(df_excl, dep_var, specs, label))

    if not all_results:
        # pd.concat raises on an empty list; report the situation instead.
        print("\nNo model could be estimated; no output written.")
        return

    # Combine and save (create the output directory on first run)
    results_df = pd.concat(all_results, ignore_index=True)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    outfile = OUTPUT_DIR / "financial_center_robustness.csv"
    results_df.to_csv(outfile, index=False)
    print(f"\nSaved: {outfile}")

    # Summary comparison
    print(f"\n{'=' * 70}")
    print("COEFFICIENT COMPARISON: dZ_1")
    print(f"{'=' * 70}")
    print(f"  {'Model':<40} {'Coef':>8} {'p-val':>8} {'% chg':>8}")
    print(f"  {'-' * 66}")

    dz1_rows = results_df[results_df['variable'] == 'dZ_1']
    for model_type in ['2b', '2c']:
        same_model = dz1_rows[dz1_rows['model'].str.startswith(model_type)]
        base_row = same_model[same_model['model'].str.contains('Full')]
        if len(base_row) == 0:
            # Without a full-sample baseline there is nothing to compare to.
            continue
        base_coef = base_row.iloc[0]['coefficient']

        for _, row in same_model.iterrows():
            # Percentage change is undefined against a zero baseline.
            if base_coef != 0:
                pct = (row['coefficient'] - base_coef) / abs(base_coef) * 100
            else:
                pct = np.nan
            sig = _significance_stars(row['p_value'])
            print(f"  {row['model']:<40} {row['coefficient']:>8.3f} {row['p_value']:>8.4f} {pct:>+7.1f}%  {sig}")

    # Also show KAOPEN interaction stability
    print(f"\n{'=' * 70}")
    print("KAOPEN INTERACTION COMPARISON: dZ_1_x_kaopen_j")
    print(f"{'=' * 70}")
    for _, row in results_df[(results_df['variable'] == 'dZ_1_x_kaopen_j')].iterrows():
        sig = _significance_stars(row['p_value'])
        print(f"  {row['model']:<40} {row['coefficient']:>8.3f} {row['p_value']:>8.4f}  {sig}")


# Run the robustness check only when executed as a script, not on import.
if __name__ == "__main__":
    main()
